import math

import cv2
import numpy as np
import pywt
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function

from einops import rearrange
from einops.layers.torch import Rearrange
from mamba_ssm import Mamba
from timm.models.layers import trunc_normal_

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PVMLayer(nn.Module):
    def __init__(self, input_dim, output_dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.norm = nn.LayerNorm(input_dim)
        self.mamba = Mamba(
            d_model=input_dim // 4,  # per-group model dim: channels are split into four chunks below
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,  # Local convolution width
            expand=expand,  # Block expansion factor
        )
        self.proj = nn.Linear(input_dim, output_dim)
        self.skip_scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.input_dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_norm = self.norm(x_flat)

        # split channels into four groups and scan each with the shared Mamba block
        x1, x2, x3, x4 = torch.chunk(x_norm, 4, dim=2)
        x_mamba1 = self.mamba(x1) + self.skip_scale * x1
        x_mamba2 = self.mamba(x2) + self.skip_scale * x2
        x_mamba3 = self.mamba(x3) + self.skip_scale * x3
        x_mamba4 = self.mamba(x4) + self.skip_scale * x4
        x_mamba = torch.cat([x_mamba1, x_mamba2, x_mamba3, x_mamba4], dim=2)

        x_mamba = self.norm(x_mamba)
        x_mamba = self.proj(x_mamba)
        out = x_mamba.transpose(-1, -2).reshape(B, self.output_dim, *img_dims)
        return out
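

# Minimal shape sketch for PVMLayer (illustrative helper, not part of the
# model). Assumes mamba_ssm is installed; the fused Mamba kernels generally
# require a CUDA device.
def _demo_pvm_layer():
    layer = PVMLayer(input_dim=32, output_dim=48).to(device)
    x = torch.randn(2, 32, 16, 16, device=device)
    y = layer(x)  # (B, C, H, W) -> tokens -> Mamba over 4 channel groups -> project
    print(y.shape)  # expected: torch.Size([2, 48, 16, 16])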


class Channel_Att_Bridge(nn.Module):
    def __init__(self, c_list, split_att='fc'):
        super().__init__()
        c_list_sum = sum(c_list) - c_list[-1]
        self.split_att = split_att
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
        self.att1 = nn.Linear(c_list_sum, c_list[0]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[0], 1)
        self.att2 = nn.Linear(c_list_sum, c_list[1]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[1], 1)
        self.att3 = nn.Linear(c_list_sum, c_list[2]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[2], 1)
        self.att4 = nn.Linear(c_list_sum, c_list[3]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[3], 1)
        self.att5 = nn.Linear(c_list_sum, c_list[4]) if split_att == 'fc' else nn.Conv1d(c_list_sum, c_list[4], 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, t1, t2, t3, t4, t5):
        att = torch.cat((self.avgpool(t1),
                         self.avgpool(t2),
                         self.avgpool(t3),
                         self.avgpool(t4),
                         self.avgpool(t5)), dim=1)
        att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
        if self.split_att != 'fc':
            att = att.transpose(-1, -2)
        att1 = self.sigmoid(self.att1(att))
        att2 = self.sigmoid(self.att2(att))
        att3 = self.sigmoid(self.att3(att))
        att4 = self.sigmoid(self.att4(att))
        att5 = self.sigmoid(self.att5(att))
        if self.split_att == 'fc':
            att1 = att1.transpose(-1, -2).unsqueeze(-1).expand_as(t1)
            att2 = att2.transpose(-1, -2).unsqueeze(-1).expand_as(t2)
            att3 = att3.transpose(-1, -2).unsqueeze(-1).expand_as(t3)
            att4 = att4.transpose(-1, -2).unsqueeze(-1).expand_as(t4)
            att5 = att5.transpose(-1, -2).unsqueeze(-1).expand_as(t5)
        else:
            att1 = att1.unsqueeze(-1).expand_as(t1)
            att2 = att2.unsqueeze(-1).expand_as(t2)
            att3 = att3.unsqueeze(-1).expand_as(t3)
            att4 = att4.unsqueeze(-1).expand_as(t4)
            att5 = att5.unsqueeze(-1).expand_as(t5)

        return att1, att2, att3, att4, att5


class Spatial_Att_Bridge(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_conv2d = nn.Sequential(nn.Conv2d(2, 1, 7, stride=1, padding=9, dilation=3),
                                           nn.Sigmoid())

    def forward(self, t1, t2, t3, t4, t5):
        t_list = [t1, t2, t3, t4, t5]
        att_list = []
        for t in t_list:
            avg_out = torch.mean(t, dim=1, keepdim=True)
            max_out, _ = torch.max(t, dim=1, keepdim=True)
            att = torch.cat([avg_out, max_out], dim=1)
            att = self.shared_conv2d(att)
            att_list.append(att)
        return att_list[0], att_list[1], att_list[2], att_list[3], att_list[4]


class SC_Att_Bridge(nn.Module):
    def __init__(self, c_list, split_att='fc'):
        super().__init__()

        self.catt = Channel_Att_Bridge(c_list, split_att=split_att)
        self.satt = Spatial_Att_Bridge()

    def forward(self, t1, t2, t3, t4, t5):
        r1, r2, r3, r4, r5 = t1, t2, t3, t4, t5

        satt1, satt2, satt3, satt4, satt5 = self.satt(t1, t2, t3, t4, t5)
        t1, t2, t3, t4, t5 = satt1 * t1, satt2 * t2, satt3 * t3, satt4 * t4, satt5 * t5

        r1_, r2_, r3_, r4_, r5_ = t1, t2, t3, t4, t5
        t1, t2, t3, t4, t5 = t1 + r1, t2 + r2, t3 + r3, t4 + r4, t5 + r5

        catt1, catt2, catt3, catt4, catt5 = self.catt(t1, t2, t3, t4, t5)
        t1, t2, t3, t4, t5 = catt1 * t1, catt2 * t2, catt3 * t3, catt4 * t4, catt5 * t5

        return t1 + r1_, t2 + r2_, t3 + r3_, t4 + r4_, t5 + r5_
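

# Shape contract of the bridge (illustrative helper): the five skip tensors
# come back re-weighted but with unchanged shapes. Sizes below assume a
# 256x256 input to UltraLight_VM_UNet (features at H/2 ... H/32).
def _demo_sc_att_bridge():
    c_list = [8, 16, 24, 32, 48, 64]
    bridge = SC_Att_Bridge(c_list)
    feats = [torch.randn(2, c, 128 // 2 ** i, 128 // 2 ** i) for i, c in enumerate(c_list[:5])]
    outs = bridge(*feats)
    for f, o in zip(feats, outs):
        assert o.shape == f.shape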


class UltraLight_VM_UNet(nn.Module):

    def __init__(self, num_classes=1, input_channels=3, c_list=[8, 16, 24, 32, 48, 64],
                 split_att='fc', bridge=True):
        super().__init__()
        self.bridge = bridge
        self.encoder1 = nn.Sequential(
            nn.Conv2d(input_channels, c_list[0], 3, stride=1, padding=1),
        )
        self.encoder2 = nn.Sequential(
            nn.Conv2d(c_list[0], c_list[1], 3, stride=1, padding=1),
        )
        self.encoder3 = nn.Sequential(
            nn.Conv2d(c_list[1], c_list[2], 3, stride=1, padding=1),
        )
        self.encoder4 = nn.Sequential(
            PVMLayer(input_dim=c_list[2], output_dim=c_list[3])
        )
        self.encoder5 = nn.Sequential(
            PVMLayer(input_dim=c_list[3], output_dim=c_list[4])
        )
        self.encoder6 = nn.Sequential(
            PVMLayer(input_dim=c_list[4], output_dim=c_list[5])
        )

        # skip-connection settings for a 256x256 input (feature maps at H/2 ... H/32)
        img_size = [128, 64, 32, 16, 8]
        num_heads_enc = [2, 4, 8, 16, 24]
        dims = c_list[:5]  # fet_block dims must track the encoder channel widths

        if bridge:
            self.scab = SC_Att_Bridge(c_list, split_att)
            self.fet_att = nn.ModuleList([fet_block(img_size[i], in_dim=dims[i], num_heads=num_heads_enc[i])
                                          for i in range(5)])  # registered so .to(device) and optimizers see them

        self.decoder1 = nn.Sequential(
            PVMLayer(input_dim=c_list[5], output_dim=c_list[4])
        )
        self.decoder2 = nn.Sequential(
            PVMLayer(input_dim=c_list[4], output_dim=c_list[3])
        )
        self.decoder3 = nn.Sequential(
            PVMLayer(input_dim=c_list[3], output_dim=c_list[2])
        )
        self.decoder4 = nn.Sequential(
            nn.Conv2d(c_list[2], c_list[1], 3, stride=1, padding=1),
        )
        self.decoder5 = nn.Sequential(
            nn.Conv2d(c_list[1], c_list[0], 3, stride=1, padding=1),
        )
        self.ebn1 = nn.GroupNorm(4, c_list[0])
        self.ebn2 = nn.GroupNorm(4, c_list[1])
        self.ebn3 = nn.GroupNorm(4, c_list[2])
        self.ebn4 = nn.GroupNorm(4, c_list[3])
        self.ebn5 = nn.GroupNorm(4, c_list[4])
        self.dbn1 = nn.GroupNorm(4, c_list[4])
        self.dbn2 = nn.GroupNorm(4, c_list[3])
        self.dbn3 = nn.GroupNorm(4, c_list[2])
        self.dbn4 = nn.GroupNorm(4, c_list[1])
        self.dbn5 = nn.GroupNorm(4, c_list[0])

        self.final = nn.Conv2d(c_list[0], num_classes, kernel_size=1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv1d):
            n = m.kernel_size[0] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):

        out = F.gelu(F.max_pool2d(self.ebn1(self.encoder1(x)), 2, 2))
        t1 = out  # b, c0, H/2, W/2

        out = F.gelu(F.max_pool2d(self.ebn2(self.encoder2(out)), 2, 2))
        t2 = out  # b, c1, H/4, W/4

        out = F.gelu(F.max_pool2d(self.ebn3(self.encoder3(out)), 2, 2))
        t3 = out  # b, c2, H/8, W/8

        out = F.gelu(F.max_pool2d(self.ebn4(self.encoder4(out)), 2, 2))
        t4 = out  # b, c3, H/16, W/16

        out = F.gelu(F.max_pool2d(self.ebn5(self.encoder5(out)), 2, 2))
        t5 = out  # b, c4, H/32, W/32

        if self.bridge:
            # alternative bridge: t1, t2, t3, t4, t5 = self.scab(t1, t2, t3, t4, t5)
            t1 = self.fet_att[0](t1)
            t2 = self.fet_att[1](t2)
            t3 = self.fet_att[2](t3)
            t4 = self.fet_att[3](t4)
            t5 = self.fet_att[4](t5)

        out = F.gelu(self.encoder6(out))  # b, c5, H/32, W/32

        out5 = F.gelu(self.dbn1(self.decoder1(out)))  # b, c4, H/32, W/32
        out5 = torch.add(out5, t5)  # b, c4, H/32, W/32

        out4 = F.gelu(F.interpolate(self.dbn2(self.decoder2(out5)), scale_factor=(2, 2), mode='bilinear',
                                    align_corners=True))  # b, c3, H/16, W/16
        out4 = torch.add(out4, t4)  # b, c3, H/16, W/16

        out3 = F.gelu(F.interpolate(self.dbn3(self.decoder3(out4)), scale_factor=(2, 2), mode='bilinear',
                                    align_corners=True))  # b, c2, H/8, W/8
        out3 = torch.add(out3, t3)  # b, c2, H/8, W/8

        out2 = F.gelu(F.interpolate(self.dbn4(self.decoder4(out3)), scale_factor=(2, 2), mode='bilinear',
                                    align_corners=True))  # b, c1, H/4, W/4
        out2 = torch.add(out2, t2)  # b, c1, H/4, W/4

        out1 = F.gelu(F.interpolate(self.dbn5(self.decoder5(out2)), scale_factor=(2, 2), mode='bilinear',
                                    align_corners=True))  # b, c0, H/2, W/2
        out1 = torch.add(out1, t1)  # b, c0, H/2, W/2

        out0 = F.interpolate(self.final(out1), scale_factor=(2, 2), mode='bilinear',
                             align_corners=True)  # b, num_class, H, W

        return torch.sigmoid(out0)


###############################
# FET attention components


class DWConv(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        B, N, C = x.shape
        tx = x.transpose(1, 2).view(B, C, H, W)
        conv_x = self.dwconv(tx)
        return conv_x.flatten(2).transpose(1, 2)


class MixFFN(nn.Module):
    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        ax = self.act(self.dwconv(self.fc1(x), H, W))
        out = self.fc2(ax)
        return out


class MixFFN_skip(nn.Module):
    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)
        self.norm1 = nn.LayerNorm(c2)

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        fx = self.fc1(x)  # project once; reused as the residual around the depthwise conv
        ax = self.act(self.norm1(self.dwconv(fx, H, W) + fx))
        out = self.fc2(ax)
        return out
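

# The Mix-FFN variants share one contract: tokens of shape (B, N, C) plus the
# (H, W) they were flattened from, so the depthwise conv can act spatially.
# A small sketch with illustrative sizes:
def _demo_mixffn():
    ffn = MixFFN_skip(c1=64, c2=256)
    x = torch.randn(2, 14 * 14, 64)
    print(ffn(x, 14, 14).shape)  # expected: torch.Size([2, 196, 64])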


class MLP_FFN(nn.Module):
    def __init__(self, c1, c2):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class MixD_FFN(nn.Module):
    def __init__(self, c1, c2, fuse_mode="add"):
        super().__init__()
        self.fc1 = nn.Linear(c1, c2)
        self.dwconv = DWConv(c2)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(c2, c1) if fuse_mode == "add" else nn.Linear(c2 * 2, c1)
        self.fuse_mode = fuse_mode

    def forward(self, x, H, W):
        fx = self.fc1(x)
        ax = self.dwconv(fx, H, W)
        fuse = self.act(ax + fx) if self.fuse_mode == "add" else self.act(torch.cat([ax, fx], 2))
        out = self.fc2(fuse)  # project the fused features (ax alone would discard the fusion)
        return out


class OverlapPatchEmbeddings(nn.Module):
    def __init__(self, img_size=224, patch_size=7, stride=4, padding=1, in_ch=3, dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_ch, dim, patch_size, stride, padding)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):  # returns (tokens, H, W)
        px = self.proj(x)
        _, _, H, W = px.shape
        fx = px.flatten(2).transpose(1, 2)
        nfx = self.norm(fx)
        return nfx, H, W


class MLP(nn.Module):
    def __init__(self, dim, embed_dim):
        super().__init__()
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.flatten(2).transpose(1, 2)
        return self.proj(x)


class ConvModule(nn.Module):
    def __init__(self, c1, c2, k):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.activate = nn.ReLU(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activate(self.bn(self.conv(x)))


class DWT_Function(Function):
    @staticmethod
    def forward(ctx, x, w_ll, w_lh, w_hl, w_hh):
        x = x.contiguous()
        ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh)
        ctx.shape = x.shape

        dim = x.shape[1]
        x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x = torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1)
        return x

    @staticmethod
    def backward(ctx, dx):
        if ctx.needs_input_grad[0]:
            w_ll, w_lh, w_hl, w_hh = ctx.saved_tensors
            B, C, H, W = ctx.shape
            # regroup the four sub-band gradients, then invert the strided
            # depthwise analysis convs with a transposed conv
            dx = dx.view(B, 4, -1, H // 2, W // 2)
            dx = dx.transpose(1, 2).reshape(B, -1, H // 2, W // 2)
            filters = torch.cat([w_ll, w_lh, w_hl, w_hh], dim=0).repeat(C, 1, 1, 1)
            dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=2, groups=C)
            return dx, None, None, None, None
        return None, None, None, None, None


class DWT_2D(nn.Module):
    def __init__(self, wave):
        super().__init__()
        w = pywt.Wavelet(wave)
        dec_hi = torch.Tensor(w.dec_hi[::-1])
        dec_lo = torch.Tensor(w.dec_lo[::-1])

        # 2D separable analysis filters: outer products of the 1D low/high-pass taps
        w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1)
        w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1)
        w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1)
        w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)

        # register as float32 buffers so they follow the module across devices
        self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0).float())
        self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0).float())
        self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0).float())
        self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0).float())

    def forward(self, x):
        return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh)
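

# Contract check for the Haar DWT (illustrative helper): one level maps
# (B, C, H, W) to (B, 4C, H/2, W/2), channel-ordered LL, LH, HL, HH.
def _demo_dwt_2d():
    dwt = DWT_2D(wave='haar')
    x = torch.randn(2, 8, 32, 32)
    print(dwt(x).shape)  # expected: torch.Size([2, 32, 16, 16])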


class EfficientAttention(nn.Module):
    """
        input  -> x:[B, D, H, W]
        output ->   [B, D, H, W]

        in_channels:    int -> Embedding Dimension
        key_channels:   int -> Key Embedding Dimension,   Best: (in_channels)
        value_channels: int -> Value Embedding Dimension, Best: (in_channels or in_channels//2)
        head_count:     int -> splits the embedding dimension into head_count parts and processes each independently
    """

    def __init__(self, in_channels, key_channels, value_channels, head_count=1):
        super().__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels

        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, input_):
        n, _, h, w = input_.size()

        keys = self.keys(input_).reshape((n, self.key_channels, h * w))
        queries = self.queries(input_).reshape(n, self.key_channels, h * w)
        values = self.values(input_).reshape((n, self.value_channels, h * w))

        head_key_channels = self.key_channels // self.head_count
        head_value_channels = self.value_channels // self.head_count

        attended_values = []
        for i in range(self.head_count):
            key = F.softmax(keys[
                            :,
                            i * head_key_channels: (i + 1) * head_key_channels,
                            :
                            ], dim=2)

            query = F.softmax(queries[
                              :,
                              i * head_key_channels: (i + 1) * head_key_channels,
                              :
                              ], dim=1)

            value = values[
                    :,
                    i * head_value_channels: (i + 1) * head_value_channels,
                    :
                    ]

            context = key @ value.transpose(1, 2)  # dk*dv
            attended_value = (context.transpose(1, 2) @ query).reshape(n, head_value_channels, h, w)  # n*dv
            attended_values.append(attended_value)

        aggregated_values = torch.cat(attended_values, dim=1)
        attention = self.reprojection(aggregated_values)

        return attention
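

# EfficientAttention forms a d_k x d_v context per head, so cost is linear in
# the number of pixels rather than quadratic. Smoke test with assumed sizes:
def _demo_efficient_attention():
    att = EfficientAttention(in_channels=32, key_channels=32, value_channels=32, head_count=4)
    x = torch.randn(2, 32, 28, 28)
    print(att(x).shape)  # expected: torch.Size([2, 32, 28, 28])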


class Attention(nn.Module):
    def __init__(self, dim, num_heads, bridge=False):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}."

        self.dim = dim
        self.bridge = bridge  # when True, the query is supplied externally (see BridgeLayer)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W, q=None):
        B, N, C = x.shape

        if self.bridge:
            q = q.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        else:
            q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


def get_kernel_gaussian(kernel_size, Sigma=1, in_channels=64):
    # build an (in_channels, 1, k, k) stack of identical 2D Gaussian kernels,
    # shaped for use as depthwise-convolution weights
    kernel_weights = cv2.getGaussianKernel(ksize=kernel_size, sigma=Sigma)
    kernel_weights = kernel_weights * kernel_weights.T
    kernel_weights = np.repeat(kernel_weights[None, ...], in_channels, axis=0)[:, None, ...]

    return kernel_weights
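

# The returned array has shape (in_channels, 1, k, k): one identical Gaussian
# per channel, ready to use as depthwise conv weights. Sketch (sizes assumed):
def _demo_gaussian_kernel():
    w = torch.from_numpy(get_kernel_gaussian(kernel_size=5, Sigma=1.6, in_channels=8)).float()
    x = torch.randn(2, 8, 32, 32)
    blurred = F.conv2d(x, w, padding='same', groups=8)  # per-channel Gaussian blur
    print(blurred.shape)  # expected: torch.Size([2, 8, 32, 32])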


class FET(nn.Module):
    def __init__(self, dim, num_heads, sr_ratio, bridge=False, L_num=3):
        super().__init__()

        assert 1 <= L_num <= 3, "Laplacian level count (L_num) must be 1, 2, or 3"
        self.L_num = L_num  # number of Laplacian pyramid levels (L0, L1, ...)

        self.bridge = bridge
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.sr_ratio = sr_ratio

        self.dwt = DWT_2D(wave='haar')
        self.reduce = nn.Sequential(
            nn.Conv2d(dim, dim // 4, kernel_size=1, padding=0, stride=1),
            nn.BatchNorm2d(dim // 4),
            nn.ReLU(inplace=True),
        )
        self.filter = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(inplace=True),
        )
        self.kv_embed = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) if sr_ratio > 1 else nn.Identity()
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 2)
        )

        self.hf_agg = nn.Conv3d(dim // 4, dim // 4, kernel_size=(3, 1, 1), bias=False, groups=dim // 4)

        # Gaussian kernels
        ### parameters
        kernel_shapes = [3, 5]
        s_value = np.power(2, 1 / 3)
        sigma = 1.6

        ### Kernel weights for the Laplacian pyramid
        if self.L_num > 1:
            self.sigma1_kernel = get_kernel_gaussian(kernel_size=kernel_shapes[0], Sigma=sigma * np.power(s_value, 1),
                                                     in_channels=dim // 4)
            self.sigma1_kernel = torch.from_numpy(self.sigma1_kernel).float().to(device)

        if self.L_num > 2:
            self.sigma2_kernel = get_kernel_gaussian(kernel_size=kernel_shapes[1], Sigma=sigma * np.power(s_value, 2),
                                                     in_channels=dim // 4)
            self.sigma2_kernel = torch.from_numpy(self.sigma2_kernel).float().to(device)

        self.boundary_lvl_agg = nn.Conv3d(dim // 4, dim // 4, kernel_size=(self.L_num, 1, 1), bias=False,
                                          groups=dim // 4)
        self.sig = nn.Sigmoid()
        self.linear_upsample = nn.Linear(dim // 4, dim)
        self.proj = nn.Linear(dim + dim, dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W, q=None):
        B, N, C = x.shape

        if self.bridge:
            q = q.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        else:
            q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        # tokens -> feature map, then reduce channels and take one Haar DWT level
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)
        x_dwt = self.dwt(self.reduce(x))
        x_dwt_filter = self.filter(x_dwt)

        kv = self.kv_embed(x_dwt_filter).reshape(B, C, -1).permute(0, 2, 1)
        kv = self.kv(kv).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # Spatial Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Efficient (linear) attention: merge heads back to (B, tokens, C) before
        # forming the C x C global context, so head and token axes are not mixed
        global_context = (F.softmax(k.transpose(1, 2).reshape(B, -1, C).transpose(1, 2), dim=2)
                          @ v.transpose(1, 2).reshape(B, -1, C))
        out_efficient_att = F.softmax(q.transpose(1, 2).reshape(B, N, C), dim=1) @ global_context

        # Boundary Attention: keep only the LH/HL/HH high-frequency sub-bands of the DWT
        x_dwt_hf = Rearrange('b (p c) h w -> b c p h w', p=4)(x_dwt)[:, :, 1:, ...]
        x_hf_agg = self.hf_agg(x_dwt_hf)[:, :, 0, ...]

        G0 = x_hf_agg
        L0 = G0[:, :, None, ...]  # Level 1
        L_layers = [L0, ]

        if self.L_num > 1:
            G1 = F.conv2d(input=x_hf_agg, weight=self.sigma1_kernel, bias=None, padding='same', groups=self.dim // 4)
            L1 = torch.sub(G0, G1)[:, :, None, ...]  # Level 2
            L_layers += [L1]
        if self.L_num > 2:
            G2 = F.conv2d(input=x_hf_agg, weight=self.sigma2_kernel, bias=None, padding='same', groups=self.dim // 4)
            L2 = torch.sub(G1, G2)[:, :, None, ...]  # Level 3
            L_layers += [L2]

        lvl_cat = torch.cat(L_layers, dim=2)

        boundary_lvl_agg = self.boundary_lvl_agg(lvl_cat)[:, :, 0, ...].permute(0, 2, 3, 1)
        boundary_att = self.linear_upsample(boundary_lvl_agg).permute(0, 3, 1, 2)

        # Value and Boundary Attention Summation
        boundary_att = Rearrange('b (n p) h w -> b n (h w) p', n=self.num_heads)(boundary_att)
        boundary_att = self.sig(boundary_att)
        v_sum_boundary_att = v + boundary_att

        # Spatial Attention @ Enhanced Value
        out_spatial_boundary = (attn @ v_sum_boundary_att).transpose(1, 2).reshape(B, N, C)

        # Final Projection
        out = self.proj(torch.cat([out_spatial_boundary, out_efficient_att], dim=-1))

        return out
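

# FET shape sketch (illustrative helper): tokens go in and come out as
# (B, H*W, dim); internally K/V live at half resolution via the Haar DWT,
# and dim must be divisible by both num_heads and 4.
def _demo_fet():
    fet = FET(dim=32, num_heads=4, sr_ratio=1).to(device)
    H = W = 16
    x = torch.randn(2, H * W, 32, device=device)
    print(fet(x, H, W).shape)  # expected: torch.Size([2, 256, 32])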


class FETBlock(nn.Module):
    """
        Input  -> x (Size: (b, (H*W), d)), H, W
        Output -> (b, (H*W), d)
    """

    def __init__(self, in_dim, num_heads, sr_ratio, token_mlp='mix_skip', bottleneck=False, L_num=3):
        super().__init__()
        self.norm1 = nn.LayerNorm(in_dim)

        if bottleneck:
            self.attn = Attention(in_dim, num_heads)
        else:
            self.attn = FET(in_dim, num_heads, sr_ratio, L_num=L_num)

        self.norm2 = nn.LayerNorm(in_dim)
        if token_mlp == 'mix':
            self.mlp = MixFFN(in_dim, int(in_dim * 4))
        elif token_mlp == 'mix_skip':
            self.mlp = MixFFN_skip(in_dim, int(in_dim * 4))
        else:
            self.mlp = MLP_FFN(in_dim, int(in_dim * 4))

    def forward(self, x: torch.Tensor, H, W) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), H, W)
        x = x + self.mlp(self.norm2(x), H, W)
        return x


# Encoder
class Encoder(nn.Module):
    def __init__(self, image_size, in_dim, num_heads, sr_ratio, layers, token_mlp='mix_skip', L_num=3):
        super().__init__()
        patch_sizes = [7, 3, 3, 3]
        strides = [4, 2, 2, 2]
        padding_sizes = [3, 1, 1, 1]

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbeddings(image_size, patch_sizes[0], strides[0], padding_sizes[0], 3,
                                                   in_dim[0])
        self.patch_embed2 = OverlapPatchEmbeddings(image_size // 4, patch_sizes[1], strides[1], padding_sizes[1],
                                                   in_dim[0], in_dim[1])
        self.patch_embed3 = OverlapPatchEmbeddings(image_size // 8, patch_sizes[2], strides[2], padding_sizes[2],
                                                   in_dim[1], in_dim[2])
        self.patch_embed4 = OverlapPatchEmbeddings(image_size // 16, patch_sizes[3], strides[3], padding_sizes[3],
                                                   in_dim[2], in_dim[3])

        # transformer encoder
        self.block1 = nn.ModuleList([
            FETBlock(in_dim[0], num_heads[0], sr_ratio[0], token_mlp, L_num=L_num)
            for _ in range(layers[0])])
        self.norm1 = nn.LayerNorm(in_dim[0])

        self.block2 = nn.ModuleList([
            FETBlock(in_dim[1], num_heads[1], sr_ratio[1], token_mlp, L_num=L_num)
            for _ in range(layers[1])])
        self.norm2 = nn.LayerNorm(in_dim[1])

        self.block3 = nn.ModuleList([
            FETBlock(in_dim[2], num_heads[2], sr_ratio[2], token_mlp, L_num=L_num)
            for _ in range(layers[2])])
        self.norm3 = nn.LayerNorm(in_dim[2])

        self.block4 = nn.ModuleList([
            FETBlock(in_dim[3], num_heads[3], sr_ratio[3], token_mlp, bottleneck=True, L_num=L_num)
            for _ in range(layers[3])])
        self.norm4 = nn.LayerNorm(in_dim[3])

    def forward(self, x: torch.Tensor):  # returns a list of four stage feature maps
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for blk in self.block2:
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for blk in self.block3:
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for blk in self.block4:
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs
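

# Encoder pyramid sketch for a 224x224 input (the configuration Model uses):
# four stages at strides 4/8/16/32. Illustrative helper, not part of the model.
def _demo_encoder():
    enc = Encoder(image_size=224, in_dim=[64, 128, 320, 512], num_heads=[2, 4, 10, 16],
                  sr_ratio=[1, 1, 1, 1], layers=[2, 2, 2, 2]).to(device)
    for f in enc(torch.randn(1, 3, 224, 224, device=device)):
        print(f.shape)
    # expected: (1, 64, 56, 56), (1, 128, 28, 28), (1, 320, 14, 14), (1, 512, 7, 7)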


# Skip Connection
class SE_1D(nn.Module):
    def __init__(self, in_channels, se_channels):
        super().__init__()

        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(in_channels, se_channels, kernel_size=1),
            nn.GELU(),
            nn.Conv1d(se_channels, in_channels, kernel_size=1),
            nn.Sigmoid())

    def forward(self, x):
        y = self.fc(x)
        return x * y


class QueryGenerator(nn.Module):
    def __init__(self, dim, dim_out):
        super().__init__()
        self.se = SE_1D(dim, dim_out)

    def forward(self, x):
        B, C, _ = x.shape
        x = self.se(x)

        # same token splits as BridgeLayer: 56^2 queries of width C, 28^2 of 2C,
        # 14^2 of 5C, 7^2 of 8C; permuting before reshape keeps tokens spatial-major
        q_lvl1 = x[..., :3136].permute(0, 2, 1)
        q_lvl2 = x[..., 3136:4704].permute(0, 2, 1).reshape(B, -1, C * 2)
        q_lvl3 = x[..., 4704:5684].permute(0, 2, 1).reshape(B, -1, C * 5)
        q_lvl4 = x[..., 5684:6076].permute(0, 2, 1).reshape(B, -1, C * 8)

        return [q_lvl1, q_lvl2, q_lvl3, q_lvl4]


class BridgeLayer(nn.Module):
    def __init__(self, dims, num_heads=[2, 4, 10, 16], sr_ratio=[1, 1, 1, 1], L_num=3):
        super().__init__()

        self.norm1 = nn.LayerNorm(dims)
        self.attn = EfficientAttention(dims, dims, dims)
        self.norm2 = nn.LayerNorm(dims)

        # Global Query Generator
        self.queries = QueryGenerator(dim=dims, dim_out=dims)

        # FET Block
        self.wave_lvl_1 = FET(dims, num_heads[0], sr_ratio[0], bridge=True, L_num=L_num)
        self.wave_lvl_2 = FET(dims * 2, num_heads[1], sr_ratio[1], bridge=True, L_num=L_num)
        self.wave_lvl_3 = FET(dims * 5, num_heads[2], sr_ratio[2], bridge=True, L_num=L_num)
        self.wave_lvl_4 = Attention(dims * 8, num_heads[3], bridge=True)

        # MixFFN
        self.mixffn1 = MixFFN_skip(dims, dims * 4)
        self.mixffn2 = MixFFN_skip(dims * 2, dims * 8)
        self.mixffn3 = MixFFN_skip(dims * 5, dims * 20)
        self.mixffn4 = MixFFN_skip(dims * 8, dims * 32)

    def forward(self, inputs):
        if isinstance(inputs, list):
            c1, c2, c3, c4 = inputs
            B, C, _, _ = c1.shape
            # flatten each pyramid level into tokens of width C; deeper levels
            # fold their extra channels into extra tokens
            c1f = c1.permute(0, 2, 3, 1).reshape(B, -1, C)  # 3136 tokens (56*56)
            c2f = c2.permute(0, 2, 3, 1).reshape(B, -1, C)  # 1568 tokens (28*28*2)
            c3f = c3.permute(0, 2, 3, 1).reshape(B, -1, C)  # 980 tokens  (14*14*5)
            c4f = c4.permute(0, 2, 3, 1).reshape(B, -1, C)  # 392 tokens  (7*7*8)

            inputs = torch.cat([c1f, c2f, c3f, c4f], -2)
        else:
            B, _, C = inputs.shape

        inputs = self.norm1(inputs)
        # 124 * 49 = 6076 = 3136 + 1568 + 980 + 392, the total token count
        inputs = Rearrange('b (h w) c -> b c h w', h=124, w=49)(inputs)
        tx1 = inputs + self.attn(inputs)
        tx1 = Rearrange('b c h w -> b (h w) c')(tx1)
        tx = self.norm2(tx1)

        # token splits: 3136 = 56^2 tokens of C; 1568 -> 784 = 28^2 tokens of 2C;
        # 980 -> 196 = 14^2 tokens of 5C; 392 -> 49 = 7^2 tokens of 8C
        lvl1 = tx[:, :3136, :].reshape(B, -1, C)
        lvl2 = tx[:, 3136:4704, :].reshape(B, -1, C * 2)
        lvl3 = tx[:, 4704:5684, :].reshape(B, -1, C * 5)
        lvl4 = tx[:, 5684:6076, :].reshape(B, -1, C * 8)

        q_lvl1, q_lvl2, q_lvl3, q_lvl4 = self.queries(tx1.permute(0, 2, 1))

        wave_lvl1_att = self.wave_lvl_1(lvl1, 56, 56, q_lvl1)
        wave_lvl2_att = self.wave_lvl_2(lvl2, 28, 28, q_lvl2)
        wave_lvl3_att = self.wave_lvl_3(lvl3, 14, 14, q_lvl3)
        wave_lvl4_att = self.wave_lvl_4(lvl4, 7, 7, q_lvl4)

        m1f = self.mixffn1(wave_lvl1_att, 56, 56).reshape(B, -1, C)
        m2f = self.mixffn2(wave_lvl2_att, 28, 28).reshape(B, -1, C)
        m3f = self.mixffn3(wave_lvl3_att, 14, 14).reshape(B, -1, C)
        m4f = self.mixffn4(wave_lvl4_att, 7, 7).reshape(B, -1, C)

        t1 = torch.cat([m1f, m2f, m3f, m4f], -2)

        tx2 = tx1 + t1

        return tx2


class BridgeBlock(nn.Module):
    def __init__(self, dims, num_heads=[2, 4, 10, 16], sr_ratio=[1, 1, 1, 1], bridge_layers=2, L_num=3):
        super().__init__()

        self.bridge = nn.ModuleList([
            BridgeLayer(dims, num_heads=num_heads, sr_ratio=sr_ratio, L_num=L_num)
            for _ in range(bridge_layers)])

    def forward(self, x):  # accepts the list of four encoder maps or a fused token tensor
        for blk in self.bridge:
            x = blk(x)

        B, _, C = x.shape
        outs = []

        # unflatten the shared token sequence back into the four pyramid levels
        sk1 = x[:, :3136, :].reshape(B, 56, 56, C).permute(0, 3, 1, 2)
        sk2 = x[:, 3136:4704, :].reshape(B, 28, 28, C * 2).permute(0, 3, 1, 2)
        sk3 = x[:, 4704:5684, :].reshape(B, 14, 14, C * 5).permute(0, 3, 1, 2)
        sk4 = x[:, 5684:6076, :].reshape(B, 7, 7, C * 8).permute(0, 3, 1, 2)

        outs.append(sk1)
        outs.append(sk2)
        outs.append(sk3)
        outs.append(sk4)

        return outs


# Decoder
class PatchExpand(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        x = self.expand(x)

        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
        x = x.view(B, -1, C // 4)
        x = self.norm(x.clone())

        return x


class FinalPatchExpand_X4(nn.Module):
    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        x = self.expand(x)
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
                      c=C // (self.dim_scale ** 2))
        x = x.view(B, -1, self.output_dim)
        x = self.norm(x.clone())

        return x
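

# Both expanders trade channels for resolution via a pixel-shuffle-style
# rearrange: PatchExpand maps (B, H*W, C) -> (B, 4*H*W, C/2); the X4 variant
# maps (B, H*W, C) -> (B, 16*H*W, C). A sketch with assumed sizes:
def _demo_patch_expand():
    up2 = PatchExpand(input_resolution=(7, 7), dim=512)
    up4 = FinalPatchExpand_X4(input_resolution=(56, 56), dim=64)
    print(up2(torch.randn(1, 49, 512)).shape)   # expected: torch.Size([1, 196, 256])
    print(up4(torch.randn(1, 3136, 64)).shape)  # expected: torch.Size([1, 50176, 64])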


class MyDecoderLayer(nn.Module):
    def __init__(self, input_size, in_out_head_sr, token_mlp_mode, n_class=9,
                 norm_layer=nn.LayerNorm, is_last=False, L_num=3):
        super().__init__()
        dims = in_out_head_sr[0]
        out_dim = in_out_head_sr[1]
        num_heads = in_out_head_sr[2]
        sr_ratio = in_out_head_sr[3]

        if not is_last:
            self.concat_linear = nn.Linear(dims * 2, out_dim)
            self.layer_up = PatchExpand(input_resolution=input_size, dim=out_dim, dim_scale=2, norm_layer=norm_layer)
            self.last_layer = None
        else:
            self.concat_linear = nn.Linear(dims * 4, out_dim)
            self.layer_up = FinalPatchExpand_X4(input_resolution=input_size, dim=out_dim, dim_scale=4,
                                                norm_layer=norm_layer)
            self.last_layer = nn.Conv2d(out_dim, n_class, 1)

        self.layer_former_1 = FETBlock(out_dim, num_heads, sr_ratio, token_mlp_mode, L_num=L_num)
        self.layer_former_2 = FETBlock(out_dim, num_heads, sr_ratio, token_mlp_mode, L_num=L_num)

        # initialize decoder weights (Xavier for projections, ones/zeros for norms)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x1, x2=None):
        if x2 is not None:  # skip connection exists
            b, h, w, c = x2.shape
            x2 = x2.view(b, -1, c)
            cat_x = torch.cat([x1, x2], dim=-1)
            cat_linear_x = self.concat_linear(cat_x)
            tran_layer_1 = self.layer_former_1(cat_linear_x, h, w)
            tran_layer_2 = self.layer_former_2(tran_layer_1, h, w)

            if self.last_layer is not None:
                out = self.last_layer(self.layer_up(tran_layer_2).view(b, 4 * h, 4 * w, -1).permute(0, 3, 1, 2))
            else:
                out = self.layer_up(tran_layer_2)
        else:
            out = self.layer_up(x1)
        return out


# Proposed Model
class Model(nn.Module):
    def __init__(self, num_classes=9, num_heads_enc=[2, 4, 10, 16], sr_ratio=[1, 1, 1, 1], num_heads_dec=[16, 8, 4, 1],
                 bridge_layers=1, token_mlp_mode="mix_skip", encoder_L_num=3, bridge_L_num=3, decoder_L_num=3):
        super().__init__()

        # Encoder
        dims, layers = [64, 128, 320, 512], [2, 2, 2, 2]
        self.backbone = Encoder(image_size=224, in_dim=dims, num_heads=num_heads_enc, sr_ratio=sr_ratio, layers=layers,
                                token_mlp=token_mlp_mode, L_num=encoder_L_num)

        # Skip Connection
        self.skip_connection = BridgeBlock(dims=dims[0], num_heads=num_heads_enc, sr_ratio=[1, 1, 1, 1],
                                           bridge_layers=bridge_layers, L_num=bridge_L_num)

        # Decoder
        d_base_feat_size = 7  # 16 for 512 input size, and 7 for 224
        in_out_head_sr = [[32, 64, num_heads_dec[0], sr_ratio[-1]], [144, 128, num_heads_dec[1], sr_ratio[-2]],
                          [288, 320, num_heads_dec[2], sr_ratio[-3]], [512, 512, num_heads_dec[3], sr_ratio[-4]]]

        self.decoder_3 = MyDecoderLayer((d_base_feat_size, d_base_feat_size), in_out_head_sr[3],
                                        token_mlp_mode, n_class=num_classes, L_num=decoder_L_num)
        self.decoder_2 = MyDecoderLayer((d_base_feat_size * 2, d_base_feat_size * 2), in_out_head_sr[2],
                                        token_mlp_mode, n_class=num_classes, L_num=decoder_L_num)
        self.decoder_1 = MyDecoderLayer((d_base_feat_size * 4, d_base_feat_size * 4), in_out_head_sr[1],
                                        token_mlp_mode, n_class=num_classes, L_num=decoder_L_num)
        self.decoder_0 = MyDecoderLayer((d_base_feat_size * 8, d_base_feat_size * 8), in_out_head_sr[0],
                                        token_mlp_mode, n_class=num_classes, L_num=decoder_L_num, is_last=True)

    def forward(self, x):
        # ---------------Encoder-------------------------
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        output_enc = self.backbone(x)
        enhanced_output = self.skip_connection(output_enc)

        b, c, _, _ = output_enc[3].shape

        # ---------------Decoder-------------------------
        tmp_3 = self.decoder_3(enhanced_output[3].permute(0, 2, 3, 1).view(b, -1, c))
        tmp_2 = self.decoder_2(tmp_3, enhanced_output[2].permute(0, 2, 3, 1))
        tmp_1 = self.decoder_1(tmp_2, enhanced_output[1].permute(0, 2, 3, 1))
        tmp_0 = self.decoder_0(tmp_1, enhanced_output[0].permute(0, 2, 3, 1))

        return tmp_0
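

# End-to-end sketch for the bridged Model (illustrative helper). The input
# size is effectively fixed at 224 by the hard-coded token splits in
# BridgeLayer/BridgeBlock.
def _demo_model():
    net = Model(num_classes=9).to(device)
    logits = net(torch.randn(1, 3, 224, 224, device=device))
    print(logits.shape)  # expected: torch.Size([1, 9, 224, 224])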


class fet_block(nn.Module):
    def __init__(self, image_size, in_dim, num_heads, patch_size=3, stride=1, padding=1, sr_ratio=1, layers=2,
                 token_mlp='mix_skip', L_num=3):
        super().__init__()
        self.patch_embed1 = OverlapPatchEmbeddings(image_size, patch_size=patch_size, stride=stride, padding=padding,
                                                   in_ch=in_dim, dim=in_dim)
        self.block1 = nn.ModuleList([
            FETBlock(in_dim, num_heads, sr_ratio, token_mlp, L_num=L_num) for _ in range(layers)])
        self.norm1 = nn.LayerNorm(in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        x, H, W = self.patch_embed1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        return x


if __name__ == '__main__':
    img = torch.randn(1, 3, 256, 256, device=device)
    model = UltraLight_VM_UNet().to(device)
    out = model(img)
    print(out.shape)  # expected: torch.Size([1, 1, 256, 256])
